InĀ [1]:
import numpy as np
from matplotlib import pyplot as plt
from IPython.display import HTML
from EM import EM
from utils import gen_data, plot_single, animate_plot
InĀ [6]:
## Example 1
np.random.seed(1111)
# Draw 5000 points from a ground-truth GMM with 3 components in 2 dimensions;
# component means are sampled inside [-50, 50] on each axis.
k, p = 3, 2
probs = [0.3, 0.5, 0.2]
X1, true_means1, true_covs1 = gen_data(k, p, 5000, probs, lim = [-50, 50])
title = r"GMM with $k=3$ components and $p=2$ dimensions"
plot_single(X1, k, title, true_param = (true_means1, true_covs1), filename = "Examples/2D case 1/original")
K-Means initialization
InĀ [7]:
mu_km1, sigma_km1, pi_km1, snapshots_km1, lls_km1 = EM(k, p, X1, 20, 1e-6, init_kmeans = True)
Initializing under K-Means...
0%| | 0/20 [00:00<?, ?it/s]
EM converged at iteration 2!
InĀ [4]:
# Animate the EM parameter trajectory against the true components.
# NOTE(review): the boolean flag presumably marks K-Means initialization for
# labeling/output purposes — confirm against utils.animate_plot.
anim = animate_plot(snapshots_km1, X1, k, true_means1, true_covs1, True, "Examples/2D case 1/kmeans")
# Close the static figure so only the JS animation below is rendered.
plt.close()
HTML(anim.to_jshtml())
The PostScript backend does not support transparency; partially transparent artists will be rendered opaque. The PostScript backend does not support transparency; partially transparent artists will be rendered opaque. The PostScript backend does not support transparency; partially transparent artists will be rendered opaque. The PostScript backend does not support transparency; partially transparent artists will be rendered opaque.
Out[4]:
InĀ [6]:
# Plot the log-likelihood per EM iteration (entry 0, the initialization value,
# is skipped so the y-axis scale is driven by the EM iterations themselves).
plt.plot(range(1, len(lls_km1)), lls_km1[1:], marker='.')
plt.title("Log-likelihood (under K-Means initialization)")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # fix: the figure previously had no y-axis label
plt.xticks(range(1, len(lls_km1)))
plt.show()
Random from data initialization
InĀ [9]:
mu_rand1, sigma_rand1, pi_rand1, snapshots_rand1, lls_rand1 = EM(k, p, X1, 20, 1e-6, init_kmeans = False)b
Initializing under random selection...
0%| | 0/20 [00:00<?, ?it/s]
InĀ [8]:
# Animate the EM trajectory for the random-from-data initialization run.
anim = animate_plot(snapshots_rand1, X1, k, true_means1, true_covs1, False, "Examples/2D case 1/random")
# Close the static figure so only the JS animation below is rendered.
plt.close()
HTML(anim.to_jshtml())
Out[8]:
InĀ [9]:
# Plot the log-likelihood per EM iteration (initialization entry 0 is skipped).
plt.plot(range(1, len(lls_rand1)), lls_rand1[1:], marker='.')
plt.title("Log-likelihood (under random from data initialization)")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # fix: the figure previously had no y-axis label
plt.xticks(range(1, len(lls_rand1)))
plt.show()
InĀ [10]:
# Report the true parameters next to both EM estimates with one data-driven
# loop instead of three copy-pasted print blocks; the printed output is
# byte-identical to the original triplicated version.
SEP = "----------------------------------"
runs = [
    ("True parameters", None, probs, true_means1, true_covs1),
    ("EM estimates: K-Means initialization", lls_km1, pi_km1, mu_km1, sigma_km1),
    ("EM estimates: random from data initialization", lls_rand1, pi_rand1, mu_rand1, sigma_rand1),
]
for run_idx, (label, lls, weights, means, covs) in enumerate(runs):
    if run_idx > 0:
        # Blank line + long rule between sections, exactly as before.
        print("\n-------------------------------------------------")
    print(label)
    if lls is not None:
        # Final log-likelihood of the run (the true parameters have none here).
        print("log-likelihood:", lls[-1])
    print(SEP)
    for i in range(k):
        print("Gaussian", i + 1)
        print("weight: ", weights[i])
        print("mean: ", means[i])
        print("cov: ")
        print(covs[i])
        print(SEP)
True parameters ---------------------------------- Gaussian 1 weight: 0.3 mean: [-40.44508008 42.50037019] cov: [[3.21543285 3.50252611] [3.50252611 5.50623995]] ---------------------------------- Gaussian 2 weight: 0.5 mean: [-15.64265768 -18.95230581] cov: [[4.81091237 3.29917686] [3.29917686 3.13286866]] ---------------------------------- Gaussian 3 weight: 0.2 mean: [-49.79901601 -26.44052756] cov: [[3.84424958 2.5870179 ] [2.5870179 4.24419944]] ---------------------------------- ------------------------------------------------- EM estimates: K-Means initialization log-likelihood: -23628.576273537983 ---------------------------------- Gaussian 1 weight: 0.4994 mean: [-15.68903827 -18.98941663] cov: [[4.8055495 3.26066693] [3.26066693 3.1212336 ]] ---------------------------------- Gaussian 2 weight: 0.2946 mean: [-40.46876646 42.40761474] cov: [[3.34420514 3.73352807] [3.73352807 5.84389619]] ---------------------------------- Gaussian 3 weight: 0.206 mean: [-49.70443489 -26.43634014] cov: [[3.72239938 2.63233028] [2.63233028 4.45707507]] ---------------------------------- ------------------------------------------------- EM estimates: random from data initialization log-likelihood: -23628.576279563604 ---------------------------------- Gaussian 1 weight: 0.4993998778588905 mean: [-15.68903613 -18.98941551] cov: [[4.80553185 3.26065792] [3.26065792 3.12122925]] ---------------------------------- Gaussian 2 weight: 0.20600012214110944 mean: [-49.70441992 -26.43633843] cov: [[3.72277543 2.632372 ] [2.632372 4.45707743]] ---------------------------------- Gaussian 3 weight: 0.2946 mean: [-40.46876646 42.40761474] cov: [[3.34420514 3.73352807] [3.73352807 5.84389619]] ----------------------------------
InĀ [12]:
## Example 2
np.random.seed(1)
# Same model family as Example 1 — 3 components in 2 dimensions — under a new seed.
k, p = 3, 2
probs = [0.3, 0.5, 0.2]
X2, true_means2, true_covs2 = gen_data(k, p, 5000, probs, lim = [-50, 50])
title = r"GMM with $k=3$ components and $p=2$ dimensions"
plot_single(X2, k, title, true_param = (true_means2, true_covs2), filename = "Examples/2D case 2/original")
K-Means initialization
InĀ [13]:
mu_km2, sigma_km2, pi_km2, snapshots_km2, lls_km2 = EM(k, p, X2, 20, 1e-6, init_kmeans = True)
Initializing under K-Means...
0%| | 0/20 [00:00<?, ?it/s]
EM converged at iteration 2!
InĀ [13]:
# Animate the K-Means-initialized EM trajectory for Example 2.
anim = animate_plot(snapshots_km2, X2, k, true_means2, true_covs2, True, "Examples/2D case 2/kmeans")
# Close the static figure so only the JS animation below is rendered.
plt.close()
HTML(anim.to_jshtml())
Out[13]:
InĀ [14]:
# Plot the log-likelihood per EM iteration (initialization entry 0 is skipped).
plt.plot(range(1, len(lls_km2)), lls_km2[1:], marker='.')
plt.title("Log-likelihood (under K-Means initialization)")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # fix: the figure previously had no y-axis label
plt.xticks(range(1, len(lls_km2)))
plt.show()
Random from data initialization
InĀ [15]:
mu_rand2, sigma_rand2, pi_rand2, snapshots_rand2, lls_rand2 = EM(k, p, X2, 20, 1e-6, init_kmeans = False)
Initializing under random selection...
0%| | 0/20 [00:00<?, ?it/s]
InĀ [16]:
# Animate the random-initialized EM trajectory for Example 2.
anim = animate_plot(snapshots_rand2, X2, k, true_means2, true_covs2, False, "Examples/2D case 2/random")
# Close the static figure so only the JS animation below is rendered.
plt.close()
HTML(anim.to_jshtml())
Out[16]:
InĀ [17]:
# Plot the log-likelihood per EM iteration (initialization entry 0 is skipped).
plt.plot(range(1, len(lls_rand2)), lls_rand2[1:], marker='.')
plt.title("Log-likelihood (under random from data initialization)")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # fix: the figure previously had no y-axis label
plt.xticks(range(1, len(lls_rand2)))
plt.show()
InĀ [18]:
# Report the true parameters next to both EM estimates with one data-driven
# loop instead of three copy-pasted print blocks; the printed output is
# byte-identical to the original triplicated version.
SEP = "----------------------------------"
runs = [
    ("True parameters", None, probs, true_means2, true_covs2),
    ("EM estimates: K-Means initialization", lls_km2, pi_km2, mu_km2, sigma_km2),
    ("EM estimates: random from data initialization", lls_rand2, pi_rand2, mu_rand2, sigma_rand2),
]
for run_idx, (label, lls, weights, means, covs) in enumerate(runs):
    if run_idx > 0:
        # Blank line + long rule between sections, exactly as before.
        print("\n-------------------------------------------------")
    print(label)
    if lls is not None:
        # Final log-likelihood of the run (the true parameters have none here).
        print("log-likelihood:", lls[-1])
    print(SEP)
    for i in range(k):
        print("Gaussian", i + 1)
        print("weight: ", weights[i])
        print("mean: ", means[i])
        print("cov: ")
        print(covs[i])
        print(SEP)
True parameters ---------------------------------- Gaussian 1 weight: 0.3 mean: [-8.29779953 22.03244934] cov: [[2.99654813 3.09455209] [3.09455209 4.59233381]] ---------------------------------- Gaussian 2 weight: 0.5 mean: [-49.98856252 -19.76674274] cov: [[5.23559795 2.64251473] [2.64251473 2.50084629]] ---------------------------------- Gaussian 3 weight: 0.2 mean: [-35.32441092 -40.76614052] cov: [[2.55790586 2.53662095] [2.53662095 4.61505524]] ---------------------------------- ------------------------------------------------- EM estimates: K-Means initialization log-likelihood: -23450.980136778926 ---------------------------------- Gaussian 1 weight: 0.499 mean: [-49.97450798 -19.74930392] cov: [[5.22876693 2.80247886] [2.80247886 2.68771396]] ---------------------------------- Gaussian 2 weight: 0.294 mean: [-8.23237847 22.07043534] cov: [[2.74538431 2.77484391] [2.77484391 4.1678629 ]] ---------------------------------- Gaussian 3 weight: 0.207 mean: [-35.36250742 -40.87559902] cov: [[2.42929148 2.46242384] [2.46242384 4.62927793]] ---------------------------------- ------------------------------------------------- EM estimates: random from data initialization log-likelihood: -30203.270937562607 ---------------------------------- Gaussian 1 weight: 0.13598903413445418 mean: [-8.23224653 22.06733363] cov: [[2.74456299 2.77220801] [2.77220801 4.16881179]] ---------------------------------- Gaussian 2 weight: 0.706 mean: [-45.69025286 -25.94355758] cov: [[ 48.65465384 -61.26990066] [-61.26990066 95.74983507]] ---------------------------------- Gaussian 3 weight: 0.15801096586554553 mean: [-8.23249202 22.07310477] cov: [[2.74609114 2.7771131 ] [2.7771131 4.16703084]] ----------------------------------
InĀ [17]:
## Example 3
np.random.seed(1234)
# A larger model this time: 4 components in 3 dimensions (no static plot —
# the 3-D case is summarized numerically below).
k, p = 4, 3
probs = [0.2, 0.3, 0.3, 0.2]
X3, true_means3, true_covs3 = gen_data(k, p, 5000, probs, lim = [-50, 50])
K-Means initialization
InĀ [18]:
mu_km3, sigma_km3, pi_km3, snapshots_km3, lls_km3 = EM(k, p, X3, 50, 1e-6, init_kmeans = True)
Initializing under K-Means...
0%| | 0/50 [00:00<?, ?it/s]
EM converged at iteration 2!
InĀ [21]:
# Plot the log-likelihood per EM iteration (initialization entry 0 is skipped).
plt.plot(range(1, len(lls_km3)), lls_km3[1:], marker='.')
plt.title("Log-likelihood (under K-Means initialization)")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # fix: the figure previously had no y-axis label
plt.xticks(range(1, len(lls_km3)))
plt.show()
Random from data initialization
InĀ [20]:
mu_rand3, sigma_rand3, pi_rand3, snapshots_rand3, lls_rand3 = EM(k, p, X3, 50, 1e-6, init_kmeans = False)
Initializing under random selection...
0%| | 0/50 [00:00<?, ?it/s]
EM converged at iteration 21!
InĀ [23]:
# Plot the log-likelihood per EM iteration (initialization entry 0 is skipped).
plt.plot(range(1, len(lls_rand3)), lls_rand3[1:], marker='.')
plt.title("Log-likelihood (under random from data initialization)")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # fix: the figure previously had no y-axis label
plt.xticks(range(1, len(lls_rand3)))
plt.show()
InĀ [24]:
# Report the true parameters next to both EM estimates with one data-driven
# loop instead of three copy-pasted print blocks; the printed output is
# byte-identical to the original triplicated version.
SEP = "----------------------------------"
runs = [
    ("True parameters", None, probs, true_means3, true_covs3),
    ("EM estimates: K-Means initialization", lls_km3, pi_km3, mu_km3, sigma_km3),
    ("EM estimates: random from data initialization", lls_rand3, pi_rand3, mu_rand3, sigma_rand3),
]
for run_idx, (label, lls, weights, means, covs) in enumerate(runs):
    if run_idx > 0:
        # Blank line + long rule between sections, exactly as before.
        print("\n-------------------------------------------------")
    print(label)
    if lls is not None:
        # Final log-likelihood of the run (the true parameters have none here).
        print("log-likelihood:", lls[-1])
    print(SEP)
    for i in range(k):
        print("Gaussian", i + 1)
        print("weight: ", weights[i])
        print("mean: ", means[i])
        print("cov: ")
        print(covs[i])
        print(SEP)
True parameters ---------------------------------- Gaussian 1 weight: 0.2 mean: [-30.84805496 12.2108771 -6.2272261 ] cov: [[4.58100941 4.02367452 2.7152851 ] [4.02367452 5.13268768 2.90279341] [2.7152851 2.90279341 2.84684934]] ---------------------------------- Gaussian 2 weight: 0.3 mean: [ 28.53585837 27.99758081 -22.74073947] cov: [[5.79663922 3.43435588 4.42211048] [3.43435588 3.50741948 2.92632845] [4.42211048 2.92632845 4.53958695]] ---------------------------------- Gaussian 3 weight: 0.3 mean: [-22.35357449 30.18721775 45.81393537] cov: [[5.14952706 3.70729243 3.91515326] [3.70729243 4.57311451 3.02505554] [3.91515326 3.02505554 4.4289102 ]] ---------------------------------- Gaussian 4 weight: 0.2 mean: [ 37.59326347 -14.218273 0.09951255] cov: [[3.57221306 2.37986685 3.98231909] [2.37986685 4.38900709 4.72353166] [3.98231909 4.72353166 7.17054494]] ---------------------------------- ------------------------------------------------- EM estimates: K-Means initialization log-likelihood: -34025.39133843038 ---------------------------------- Gaussian 1 weight: 0.2002 mean: [-30.80811576 12.23659367 -6.19080811] cov: [[4.57947439 3.90153426 2.68918687] [3.90153426 4.89931868 2.76964227] [2.68918687 2.76964227 2.92230095]] ---------------------------------- Gaussian 2 weight: 0.2042 mean: [ 37.64982941 -14.22471457 0.22265946] cov: [[3.50037576 2.39208439 3.91093861] [2.39208439 4.23480725 4.58936252] [3.91093861 4.58936252 6.97787477]] ---------------------------------- Gaussian 3 weight: 0.2914 mean: [-22.37836388 30.1563024 45.85063502] cov: [[5.17938414 3.75722081 3.89564633] [3.75722081 4.59443085 3.05729029] [3.89564633 3.05729029 4.41143479]] ---------------------------------- Gaussian 4 weight: 0.3042 mean: [ 28.50360451 28.01311823 -22.77540379] cov: [[5.85047774 3.41680497 4.46232446] [3.41680497 3.48518082 2.89449276] [4.46232446 2.89449276 4.50299112]] ---------------------------------- ------------------------------------------------- EM 
estimates: random from data initialization log-likelihood: -34025.39133843038 ---------------------------------- Gaussian 1 weight: 0.3042 mean: [ 28.50360451 28.01311823 -22.77540379] cov: [[5.85047774 3.41680497 4.46232446] [3.41680497 3.48518082 2.89449276] [4.46232446 2.89449276 4.50299112]] ---------------------------------- Gaussian 2 weight: 0.2914 mean: [-22.37836388 30.1563024 45.85063502] cov: [[5.17938414 3.75722081 3.89564633] [3.75722081 4.59443085 3.05729029] [3.89564633 3.05729029 4.41143479]] ---------------------------------- Gaussian 3 weight: 0.2002 mean: [-30.80811576 12.23659367 -6.19080811] cov: [[4.57947439 3.90153426 2.68918687] [3.90153426 4.89931868 2.76964227] [2.68918687 2.76964227 2.92230095]] ---------------------------------- Gaussian 4 weight: 0.2042 mean: [ 37.64982941 -14.22471457 0.22265946] cov: [[3.50037576 2.39208439 3.91093861] [2.39208439 4.23480725 4.58936252] [3.91093861 4.58936252 6.97787477]] ----------------------------------
Example 4
$k = 3$ clusters and $p=2$ dimensions
Here, we specify the clusters to be close to each other.
InĀ [2]:
## Example 4
np.random.seed(1111)
# 3 components in 2 dimensions, but with means placed close together so the
# clusters overlap — a harder problem than Examples 1-3.
k, p = 3, 2
probs = [0.3, 0.5, 0.2]
true_means4 = np.array([[0, 0], [7, 2], [-7, -2]])
true_covs4 = np.array([[[5, 4], [4, 5]], [[3, -2], [-2, 3]], [[3, -2], [-2, 4]]])
# Draw each point by sampling its component label z and then sampling from
# that component's Gaussian (one label draw + one normal draw per point).
rows = []
for _ in range(5000):
    z = np.random.choice(k, p = probs)
    rows.append(np.random.multivariate_normal(true_means4[z], true_covs4[z]))
X4 = np.array(rows)
title = r"GMM with $k=3$ components and $p=2$ dimensions"
plot_single(X4, k, title, true_param = (true_means4, true_covs4), filename = "Examples/2D case 4/original")
K-Means initialization
InĀ [3]:
mu_km4, sigma_km4, pi_km4, snapshots_km4, lls_km4 = EM(k, p, X4, 50, 1e-6, init_kmeans = True)
Initializing under K-Means...
0%| | 0/50 [00:00<?, ?it/s]
EM converged at iteration 46!
InĀ [4]:
# Animate the K-Means-initialized EM trajectory for Example 4.
anim = animate_plot(snapshots_km4, X4, k, true_means4, true_covs4, True, "Examples/2D case 4/kmeans")
# Close the static figure so only the JS animation below is rendered.
plt.close()
HTML(anim.to_jshtml())
Out[4]:
InĀ [5]:
# Plot the log-likelihood per EM iteration (initialization entry 0 is skipped).
plt.plot(range(1, len(lls_km4)), lls_km4[1:], marker='.')
plt.title("Log-likelihood (under K-Means initialization)")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # fix: the figure previously had no y-axis label
plt.xticks(range(1, len(lls_km4)))
plt.show()
Random from data initialization
InĀ [6]:
mu_rand4, sigma_rand4, pi_rand4, snapshots_rand4, lls_rand4 = EM(k, p, X4, 50, 1e-6, init_kmeans = False)
Initializing under random selection...
0%| | 0/50 [00:00<?, ?it/s]
InĀ [7]:
# Animate the random-initialized EM trajectory for Example 4.
anim = animate_plot(snapshots_rand4, X4, k, true_means4, true_covs4, False, "Examples/2D case 4/random")
# Close the static figure so only the JS animation below is rendered.
plt.close()
HTML(anim.to_jshtml())
Out[7]:
InĀ [8]:
# Plot the log-likelihood per EM iteration (initialization entry 0 is skipped).
plt.plot(range(1, len(lls_rand4)), lls_rand4[1:], marker='.')
plt.title("Log-likelihood (under random from data initialization)")
plt.xlabel("Iteration")
plt.ylabel("Log-likelihood")  # fix: the figure previously had no y-axis label
plt.xticks(range(1, len(lls_rand4)))
plt.show()